Next Character Prediction with RNNs

We will not only predict the character that follows a given sequence, but also try to predict the next character at every timestep of the sequence.
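To make the task concrete, here is a minimal sketch of how a sliding window over raw text produces (input, next-character) training pairs. The helper `make_pairs` is hypothetical and not part of the metu package, which does the equivalent work for us below.

import numpy as np

# Hypothetical helper (not part of the metu package): build (input, next-char)
# pairs by sliding a window of width input_size over the raw text.
def make_pairs(text, input_size=5):
    X, y = [], []
    for i in xrange(len(text) - input_size):
        X.append([ord(c) for c in text[i:i + input_size]])  # codes of the window characters
        y.append(ord(text[i + input_size]))                  # code of the character that follows
    return np.asarray(X), np.asarray(y)

# make_pairs("THE SONNETS") yields pairs such as "THE S" -> "O", "HE SO" -> "N", ...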


In [104]:
# As usual, a bit of setup

import time, os, json
import numpy as np
import matplotlib.pyplot as plt

from cs231n.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array
from cs231n.rnn_layers import *
from cs231n.captioning_solver import *
from cs231n.classifiers.rnn import *
from cs231n.coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from cs231n.image_utils import image_from_url

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

def rel_error(x, y):
  """ returns relative error """
  return np.max(np.abs(x - y) / (np.maximum(1e-8, np.abs(x) + np.abs(y))))


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

In [2]:
from metu.data_utils import load_nextchar_dataset, plain_text_file_to_dataset

In [60]:
# Load the TEXT data
# If your memory turns out to be sufficient, try the following:
#def get_nextchar_data(training_ratio=0.6, val_ratio=0.1):
def get_nextchar_data(training_ratio=0.1, test_ratio=0.06, val_ratio=0.01):
  # Load the nextchar training data 
  X, y = load_nextchar_dataset(nextchar_datafile)
  # Subsample the data
  length=len(y)
  num_training=int(length*training_ratio)
  num_val = int(length*val_ratio)
  num_test = min((length-num_training-num_val), int(length*test_ratio))
  mask = range(num_training-1)
  X_train = X[mask]
  y_train = y[mask]
  mask = range(num_training, num_training+num_test)
  X_test = X[mask]
  y_test = y[mask]
  mask = range(num_training+num_test, num_training+num_test+num_val)
  X_val = X[mask]
  y_val = y[mask]

  return X_train, y_train, X_val, y_val, X_test, y_test

nextchar_datafile = 'metu/dataset/nextchar_data.pkl'
input_size = 5 # Size of the input of the network
#plain_text_file_to_dataset("metu/dataset/ince_memed_1.txt", nextchar_datafile, input_size)
plain_text_file_to_dataset("metu/dataset/shakespeare.txt", nextchar_datafile, input_size)
X_train, y_train, X_val, y_val, X_test, y_test = get_nextchar_data()

# Append the label as a final timestep and add a trailing feature axis, turning the
# (N, input_size) inputs into (N, input_size+1, 1) time-series arrays for the RNN.
NX_train = np.zeros((X_train.shape[0], input_size+1, 1))
for i in xrange(X_train.shape[0]):
    for j in xrange(input_size):
        NX_train[i,j,0] = X_train[i,j]
    NX_train[i,input_size,0] = y_train[i]

NX_test = np.zeros((X_test.shape[0], input_size+1, 1))
for i in xrange(X_test.shape[0]):
    for j in xrange(input_size):
        NX_test[i,j,0] = X_test[i,j]
    NX_test[i,input_size,0] = y_test[i]
        
NX_val = np.zeros((X_val.shape[0], input_size+1, 1))
for i in xrange(X_val.shape[0]):
    for j in xrange(input_size):
        NX_val[i,j,0] = X_val[i,j]
    NX_val[i,input_size,0] = y_val[i]

X_train, X_val, X_test = NX_train, NX_val, NX_test
print "Number of instances in the training set: ", len(X_train)
print "Number of instances in the validation set: ", len(X_val)
print "Number of instances in the testing set: ", len(X_test)


Converting plain text file to trainable dataset (as pickle file)
Processing file metu/dataset/shakespeare.txt as input
input_size parameter (i.e. num of neurons) will be 5
Writing data and labels to file metu/dataset/nextchar_data.pkl
Loading X and Y from pickle file metu/dataset/nextchar_data.pkl
Number of instances in the training set:  37647
Number of instances in the validation set:  3764
Number of instances in the testing set:  22589

In the code above, we reshaped X_train, X_val and X_test into time-indexed arrays of shape (N, input_size+1, 1), so that they are now suitable for use in RNNs.
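As an aside, the same reshaping can be written with a couple of NumPy operations instead of the explicit loops. The helper below is only a hypothetical sketch equivalent to the loop-based code above.

# Sketch: append the label as a final timestep and add a trailing feature axis.
def to_time_series(X, y):
    # X: (N, input_size) character codes, y: (N,) next-character codes.
    seq = np.concatenate([X, y[:, np.newaxis]], axis=1)   # (N, input_size+1)
    return seq[:, :, np.newaxis].astype(np.float64)       # (N, input_size+1, 1)

# e.g. to_time_series(X_train, y_train) produces the same array as NX_train above.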


In [125]:
# We have loaded the dataset. That wasn't difficult, was it? :)
# Let's look at a few samples
#
from metu.data_utils import int_list_to_string, int_to_charstr

print "Input - Next char to be predicted"
for i in range(1,10):
    print int_list_to_string(X_train[i]) + " - " + int_list_to_string(y_train[i])


Input - Next char to be predicted
HE SON - N
E SONN - N
 SONNE - E
SONNET - T
ONNETS - S
by Wil - l
y Will - l
 Willi - i
Willia - a

I simply modified the CaptioningRNN code: instead of an initial hidden state fed from a CNN, I initialize it with all zeros, and I removed the word-embedding layer, since we work with characters and can simply use their ASCII codes as input.

There are also a few modifications in the loss and sample functions to reflect these changes.
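For reference, here is a minimal sketch of what the modified loss might look like. It uses the cs231n layer functions rnn_forward, temporal_affine_forward and temporal_softmax_loss and assumes parameter names such as Wx, Wh, b, W_vocab and b_vocab; the actual NextCharRNN code may differ in detail.

# Hypothetical sketch of the modified loss (the real NextCharRNN may differ).
def loss_sketch(self, chars):
    # chars: (N, T+1, 1) character codes; inputs are timesteps 0..T-1 and the
    # targets are timesteps 1..T, i.e. every "next character".
    x = chars[:, :-1, :]                               # (N, T, 1) raw codes as features
    y = chars[:, 1:, 0].astype(int)                    # (N, T) next-character targets
    N, T, _ = x.shape
    H = self.params['Wh'].shape[0]
    h0 = np.zeros((N, H))                              # all-zero initial hidden state (no CNN features)
    h, _ = rnn_forward(x, h0, self.params['Wx'], self.params['Wh'], self.params['b'])
    scores, _ = temporal_affine_forward(h, self.params['W_vocab'], self.params['b_vocab'])
    loss, _ = temporal_softmax_loss(scores, y, np.ones((N, T), dtype=bool))
    return loss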


In [124]:
small_rnn_model = NextCharRNN(
          cell_type='rnn',
          input_dim=input_size,
          hidden_dim=512,
          charvec_dim=1,
        )

small_rnn_solver = NextCharSolver(small_rnn_model, X_train,
           update_rule='adam',
           num_epochs=50,
           batch_size=100,
           optim_config={
             'learning_rate': 1e-2,
           },
           lr_decay=0.95,
           verbose=True, print_every=100,
         )
small_rnn_solver.train()

# Plot the training losses
plt.plot(small_rnn_solver.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()


(Iteration 1 / 18800) loss: 34.658135
(Iteration 101 / 18800) loss: 12.512869
(Iteration 201 / 18800) loss: 10.823078
(Iteration 301 / 18800) loss: 9.693160
(Iteration 401 / 18800) loss: 10.800507
(Iteration 501 / 18800) loss: 9.001371
(Iteration 601 / 18800) loss: 8.611452
(Iteration 701 / 18800) loss: 7.388417
(Iteration 801 / 18800) loss: 8.138009
(Iteration 901 / 18800) loss: 10.116934
(Iteration 1001 / 18800) loss: 6.771510
(Iteration 1101 / 18800) loss: 7.558995
(Iteration 1201 / 18800) loss: 7.157235
(Iteration 1301 / 18800) loss: 6.536156
(Iteration 1401 / 18800) loss: 7.373561
(Iteration 1501 / 18800) loss: 7.531563
(Iteration 1601 / 18800) loss: 6.056296
(Iteration 1701 / 18800) loss: 7.522122
(Iteration 1801 / 18800) loss: 6.186926
(Iteration 1901 / 18800) loss: 8.257967
(Iteration 2001 / 18800) loss: 5.662087
(Iteration 2101 / 18800) loss: 6.062511
(Iteration 2201 / 18800) loss: 5.716082
(Iteration 2301 / 18800) loss: 5.441693
(Iteration 2401 / 18800) loss: 6.314132
(Iteration 2501 / 18800) loss: 4.933623
(Iteration 2601 / 18800) loss: 5.362701
(Iteration 2701 / 18800) loss: 6.952561
(Iteration 2801 / 18800) loss: 5.747098
(Iteration 2901 / 18800) loss: 6.269890
(Iteration 3001 / 18800) loss: 4.577306
(Iteration 3101 / 18800) loss: 4.981474
(Iteration 3201 / 18800) loss: 4.770118
(Iteration 3301 / 18800) loss: 5.976895
(Iteration 3401 / 18800) loss: 5.111784
(Iteration 3501 / 18800) loss: 5.872612
(Iteration 3601 / 18800) loss: 4.840131
(Iteration 3701 / 18800) loss: 5.299153
(Iteration 3801 / 18800) loss: 5.454993
(Iteration 3901 / 18800) loss: 5.424509
(Iteration 4001 / 18800) loss: 5.164655
(Iteration 4101 / 18800) loss: 4.980426
(Iteration 4201 / 18800) loss: 5.388069
(Iteration 4301 / 18800) loss: 5.182836
(Iteration 4401 / 18800) loss: 5.268860
(Iteration 4501 / 18800) loss: 5.365909
(Iteration 4601 / 18800) loss: 5.210092
(Iteration 4701 / 18800) loss: 5.039283
(Iteration 4801 / 18800) loss: 4.901982
(Iteration 4901 / 18800) loss: 4.994670
(Iteration 5001 / 18800) loss: 4.501295
(Iteration 5101 / 18800) loss: 4.720318
(Iteration 5201 / 18800) loss: 4.808735
(Iteration 5301 / 18800) loss: 4.696186
(Iteration 5401 / 18800) loss: 4.638732
(Iteration 5501 / 18800) loss: 4.143391
(Iteration 5601 / 18800) loss: 4.242820
(Iteration 5701 / 18800) loss: 4.619543
(Iteration 5801 / 18800) loss: 4.198672
(Iteration 5901 / 18800) loss: 4.028125
(Iteration 6001 / 18800) loss: 4.075829
(Iteration 6101 / 18800) loss: 4.383079
(Iteration 6201 / 18800) loss: 4.762589
(Iteration 6301 / 18800) loss: 4.198020
(Iteration 6401 / 18800) loss: 4.598550
(Iteration 6501 / 18800) loss: 4.548510
(Iteration 6601 / 18800) loss: 4.570832
(Iteration 6701 / 18800) loss: 4.712333
(Iteration 6801 / 18800) loss: 4.864204
(Iteration 6901 / 18800) loss: 4.498150
(Iteration 7001 / 18800) loss: 4.480478
(Iteration 7101 / 18800) loss: 4.031277
(Iteration 7201 / 18800) loss: 3.829759
(Iteration 7301 / 18800) loss: 4.266573
(Iteration 7401 / 18800) loss: 4.000300
(Iteration 7501 / 18800) loss: 4.041948
(Iteration 7601 / 18800) loss: 3.860006
(Iteration 7701 / 18800) loss: 4.135464
(Iteration 7801 / 18800) loss: 4.053984
(Iteration 7901 / 18800) loss: 4.227810
(Iteration 8001 / 18800) loss: 3.946882
(Iteration 8101 / 18800) loss: 3.813261
(Iteration 8201 / 18800) loss: 3.814829
(Iteration 8301 / 18800) loss: 3.884154
(Iteration 8401 / 18800) loss: 4.166712
(Iteration 8501 / 18800) loss: 4.498297
(Iteration 8601 / 18800) loss: 3.734446
(Iteration 8701 / 18800) loss: 3.568519
(Iteration 8801 / 18800) loss: 3.930513
(Iteration 8901 / 18800) loss: 3.934761
(Iteration 9001 / 18800) loss: 4.297010
(Iteration 9101 / 18800) loss: 4.080992
(Iteration 9201 / 18800) loss: 4.111183
(Iteration 9301 / 18800) loss: 3.680349
(Iteration 9401 / 18800) loss: 4.201587
(Iteration 9501 / 18800) loss: 3.984535
(Iteration 9601 / 18800) loss: 3.676146
(Iteration 9701 / 18800) loss: 3.929846
(Iteration 9801 / 18800) loss: 3.522593
(Iteration 9901 / 18800) loss: 3.692506
(Iteration 10001 / 18800) loss: 3.802196
(Iteration 10101 / 18800) loss: 3.654503
(Iteration 10201 / 18800) loss: 3.546009
(Iteration 10301 / 18800) loss: 3.147313
(Iteration 10401 / 18800) loss: 3.250645
(Iteration 10501 / 18800) loss: 3.641899
(Iteration 10601 / 18800) loss: 3.264626
(Iteration 10701 / 18800) loss: 3.235490
(Iteration 10801 / 18800) loss: 3.922281
(Iteration 10901 / 18800) loss: 3.542055
(Iteration 11001 / 18800) loss: 3.835719
(Iteration 11101 / 18800) loss: 3.904201
(Iteration 11201 / 18800) loss: 3.519926
(Iteration 11301 / 18800) loss: 3.652303
(Iteration 11401 / 18800) loss: 3.527565
(Iteration 11501 / 18800) loss: 3.747745
(Iteration 11601 / 18800) loss: 3.663973
(Iteration 11701 / 18800) loss: 3.456291
(Iteration 11801 / 18800) loss: 3.556505
(Iteration 11901 / 18800) loss: 3.250694
(Iteration 12001 / 18800) loss: 3.906171
(Iteration 12101 / 18800) loss: 3.156332
(Iteration 12201 / 18800) loss: 2.952408
(Iteration 12301 / 18800) loss: 3.379865
(Iteration 12401 / 18800) loss: 3.360181
(Iteration 12501 / 18800) loss: 3.194499
(Iteration 12601 / 18800) loss: 3.242695
(Iteration 12701 / 18800) loss: 3.227070
(Iteration 12801 / 18800) loss: 3.467368
(Iteration 12901 / 18800) loss: 3.403149
(Iteration 13001 / 18800) loss: 3.697911
(Iteration 13101 / 18800) loss: 2.615750
(Iteration 13201 / 18800) loss: 3.402369
(Iteration 13301 / 18800) loss: 3.378056
(Iteration 13401 / 18800) loss: 3.476986
(Iteration 13501 / 18800) loss: 3.571559
(Iteration 13601 / 18800) loss: 3.089966
(Iteration 13701 / 18800) loss: 3.286356
(Iteration 13801 / 18800) loss: 3.284430
(Iteration 13901 / 18800) loss: 3.314976
(Iteration 14001 / 18800) loss: 3.224381
(Iteration 14101 / 18800) loss: 3.145564
(Iteration 14201 / 18800) loss: 2.894448
(Iteration 14301 / 18800) loss: 3.231164
(Iteration 14401 / 18800) loss: 2.856714
(Iteration 14501 / 18800) loss: 3.484690
(Iteration 14601 / 18800) loss: 3.137401
(Iteration 14701 / 18800) loss: 2.844859
(Iteration 14801 / 18800) loss: 3.223377
(Iteration 14901 / 18800) loss: 3.473727
(Iteration 15001 / 18800) loss: 3.221034
(Iteration 15101 / 18800) loss: 3.150615
(Iteration 15201 / 18800) loss: 3.215401
(Iteration 15301 / 18800) loss: 3.539913
(Iteration 15401 / 18800) loss: 3.466231
(Iteration 15501 / 18800) loss: 2.787016
(Iteration 15601 / 18800) loss: 3.628873
(Iteration 15701 / 18800) loss: 3.309148
(Iteration 15801 / 18800) loss: 3.222733
(Iteration 15901 / 18800) loss: 3.141689
(Iteration 16001 / 18800) loss: 2.990506
(Iteration 16101 / 18800) loss: 3.606727
(Iteration 16201 / 18800) loss: 3.322895
(Iteration 16301 / 18800) loss: 3.065178
(Iteration 16401 / 18800) loss: 3.146902
(Iteration 16501 / 18800) loss: 2.962376
(Iteration 16601 / 18800) loss: 3.096672
(Iteration 16701 / 18800) loss: 3.303249
(Iteration 16801 / 18800) loss: 2.952140
(Iteration 16901 / 18800) loss: 3.014503
(Iteration 17001 / 18800) loss: 3.130294
(Iteration 17101 / 18800) loss: 3.152379
(Iteration 17201 / 18800) loss: 3.372345
(Iteration 17301 / 18800) loss: 2.863457
(Iteration 17401 / 18800) loss: 3.183278
(Iteration 17501 / 18800) loss: 2.788367
(Iteration 17601 / 18800) loss: 3.130347
(Iteration 17701 / 18800) loss: 3.187321
(Iteration 17801 / 18800) loss: 3.170372
(Iteration 17901 / 18800) loss: 3.059888
(Iteration 18001 / 18800) loss: 2.712597
(Iteration 18101 / 18800) loss: 2.909676
(Iteration 18201 / 18800) loss: 3.004170
(Iteration 18301 / 18800) loss: 2.987542
(Iteration 18401 / 18800) loss: 2.455133
(Iteration 18501 / 18800) loss: 3.278536
(Iteration 18601 / 18800) loss: 2.940110
(Iteration 18701 / 18800) loss: 3.088001

In [126]:
mbs=10
idx = np.random.choice(len(X_train), mbs)
minibatch = X_train[idx]
next_chars = small_rnn_model.sample(minibatch, 6)
print 'Training Data:'
for i in xrange(mbs):
    print 'Predicted string:', int_list_to_string(next_chars[i,:])
    print 'Real string:', int_list_to_string(minibatch[i,:])
    print

idx = np.random.choice(len(X_val), mbs)
minibatch = X_val[idx]
next_chars = small_rnn_model.sample(minibatch, 6)
print 'Validation Data:'
for i in xrange(mbs):
    print 'Predicted string:', int_list_to_string(next_chars[i,:])
    print 'Real string:', int_list_to_string(minibatch[i,:])
    print


Training Data:
Predicted string:  my fa
Real string:  my fa

Predicted string:  know 
Real string:  know,

Predicted string: herina
Real string: herina

Predicted string: aour  
Real string: bour, 

Predicted string: hr? ho
Real string: ir? ho

Predicted string: Iucrec
Real string: Lucrec

Predicted string:  to al
Real string:  to al

Predicted string: eive m
Real string: give m

Predicted string: s Iath
Real string: r Kath

Predicted string:  oath 
Real string:  oath,

Validation Data:
Predicted string: hll wi
Real string: ill wi

Predicted string:  such 
Real string:  such 

Predicted string:  loue 
Real string:  love 

Predicted string: I may 
Real string: I may 

Predicted string: oe in 
Real string: ne in 

Predicted string: Art to
Real string: Art to

Predicted string:       
Real string:       

Predicted string: aooare
Real string: appare

Predicted string:  lusty
Real string:  lusty

Predicted string: oe unw
Real string: ne unw

We can see that, compared to our 2-layer fully-connected approach in HW1, the RNN performs considerably better. This is already the case at a softmax loss of about 3, which could be reduced further, as the loss history graph suggests.
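As a rough sanity check (assuming the reported value is a cs231n-style temporal softmax loss, summed over the 5 predicted characters and averaged over the minibatch), a loss of about 3 corresponds to a per-character perplexity of roughly exp(3/5) ≈ 1.8:

import numpy as np

# Hypothetical back-of-the-envelope conversion, assuming the loss sums over T=5 predictions.
per_char_ce = 3.0 / 5                                   # average cross-entropy per character (nats)
print 'Approximate per-character perplexity: %.2f' % np.exp(per_char_ce)   # ~1.82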